"""
A rule to compile a C++ file to a header containing LLVM IR.
"""

load("@rules_cc//cc:find_cc_toolchain.bzl", "find_cc_toolchain", "use_cc_toolchain")
load("@rules_cc//cc/common:cc_common.bzl", "cc_common")
load("@rules_cc//cc/common:cc_info.bzl", "CcInfo")
load("//xla/tsl:package_groups.bzl", "DEFAULT_LOAD_VISIBILITY")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

# Restrict load() access to this .bzl file to the targets listed in
# DEFAULT_LOAD_VISIBILITY.
visibility(DEFAULT_LOAD_VISIBILITY)

def to_camel_case(s):
    """Returns `s` rewritten from snake_case or kebab-case into CamelCase.

    Hyphens are treated like underscores. Each underscore-separated word is
    passed through `capitalize()` (first character upper-cased, the rest
    lower-cased) and the words are concatenated without a separator.

    Args:
      s: The string to convert.

    Returns:
      The CamelCase form of `s`.
    """
    camel = ""
    for word in s.replace("-", "_").split("_"):
        camel += word.capitalize()
    return camel

def _cc_ir_header_impl(ctx):
    """Compiles `src` to LLVM IR and embeds the IR text in a generated C++ header.

    The source is compiled with `-S -emit-llvm`, so the file that
    `cc_common.compile` reports as an object is actually LLVM IR. That file is
    copied to `<label name>.ll` and its contents are wrapped in a raw string
    literal named `k<CamelCase base_name>Ir` inside the configured namespace.

    Args:
      ctx: The rule context.

    Returns:
      A list holding a DefaultInfo that exposes the generated header and a
      CcInfo whose compilation context lists that header.
    """
    cc_toolchain = find_cc_toolchain(ctx)
    feature_configuration = cc_common.configure_features(
        ctx = ctx,
        cc_toolchain = cc_toolchain,
    )
    compilation_contexts = [dep[CcInfo].compilation_context for dep in ctx.attr.deps]
    output_header = ctx.outputs.out_header

    ir_file = ctx.actions.declare_file("{}.ll".format(ctx.label.name))

    # -S -emit-llvm makes the compile action write LLVM IR instead of an object.
    cxx_flags = [
        "-S",
        "-emit-llvm",
        "-O3",
        "-DEIGEN_VECTORIZE_GENERIC",
    ]

    # cc_common.compile returns (compilation_context, compilation_outputs);
    # only the outputs are needed here.
    _compile_context, compilation_outputs = cc_common.compile(
        actions = ctx.actions,
        feature_configuration = feature_configuration,
        cc_toolchain = cc_toolchain,
        srcs = ctx.files.src,
        compilation_contexts = compilation_contexts,
        cxx_flags = cxx_flags,
        name = "{}_compiler".format(ctx.label.name),
    )

    # Prefer the PIC output when one exists, otherwise use the non-PIC one.
    # Fail with a readable message instead of an index error when the compile
    # produced neither.
    ir_candidates = compilation_outputs.pic_objects
    if not ir_candidates:
        ir_candidates = compilation_outputs.objects
    if not ir_candidates:
        fail("cc_ir_header: compiling the source of {} produced no output to embed".format(ctx.label))
    temp_ir_output = ir_candidates[0]

    # Copy the compiler's output to our declared intermediate file. Paths are
    # passed as positional arguments and quoted so the shell never re-splits them.
    ctx.actions.run_shell(
        inputs = [temp_ir_output],
        outputs = [ir_file],
        arguments = [temp_ir_output.path, ir_file.path],
        command = 'cp "$1" "$2"',
        mnemonic = "CopyLLVMIR",
    )

    # Prepare the C++ variable definition for the header. The $(cat "$2") is
    # expanded by the shell inside the unquoted here-document below, which
    # splices the IR text into the raw string literal.
    ir_definition = 'inline constexpr char k{base_name}Ir[] = R"IR($(cat "$2"))IR";'.format(
        base_name = to_camel_case(ctx.attr.base_name),
    )

    # Generate the final C++ header file ($1 = header path, $2 = IR file path).
    ctx.actions.run_shell(
        inputs = [ir_file],
        outputs = [output_header],
        arguments = [output_header.path, ir_file.path],
        mnemonic = "EmbeddingLLVMIR",
        command = """
cat <<EOF > "$1"
#pragma once

// Generated by cc_ir_header rule. DO NOT EDIT.

namespace {namespace} {{

{defs}

}} // namespace {namespace}
EOF
""".format(
            defs = ir_definition,
            namespace = ctx.attr.namespace,
        ),
        progress_message = "Embedding LLVM IR into header for %s" % ctx.label,
    )
    compilation_context = cc_common.create_compilation_context(headers = depset([output_header]))
    cc_info = CcInfo(compilation_context = compilation_context)

    return [DefaultInfo(files = depset([output_header])), cc_info]

# Attribute schema of the private generator rule: one source file, its C++
# dependencies, the declared output header, and the naming knobs used when the
# IR is embedded.
_CC_IR_HEADER_ATTRS = {
    "src": attr.label(
        doc = "The C++ source file to compile.",
        mandatory = True,
        allow_single_file = True,
    ),
    "deps": attr.label_list(providers = [CcInfo]),
    "out_header": attr.output(
        doc = "The output header file.",
        mandatory = True,
    ),
    "base_name": attr.string(
        doc = "The base name of the generated IR variables.",
        mandatory = True,
    ),
    "namespace": attr.string(
        doc = "The C++ namespace for the generated IR variables.",
        default = "llvm_ir",
    ),
    "_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"),
}

# Private rule backing the cc_ir_header macro; it needs the C++ toolchain and
# the "cpp" configuration fragment.
_cc_ir_header_rule = rule(
    attrs = _CC_IR_HEADER_ATTRS,
    fragments = ["cpp"],
    implementation = _cc_ir_header_impl,
    toolchains = use_cc_toolchain(),
)

def cc_ir_header(name, src, deps, namespace = "llvm_ir", **kwargs):
    """A macro that generates an IR header and wraps it in a cc_library.

    Two targets are created: `<name>_generator`, which produces `<name>.h`, and
    a cc_library called `<name>` that exposes that header.

    Args:
      name: The name of the generated cc_library.
      src: The C++ source file to compile.
      deps: The C++ dependencies of the source file.
      namespace: The C++ namespace that wraps the generated IR variable.
      **kwargs: Additional arguments forwarded to BOTH the generator rule and
        the generated cc_library, so they must be attributes that both accept.
        Any `tags` given here are applied to both targets; the generator
        additionally always receives the "manual" tag.
    """
    out_header = name + ".h"
    generator_name = name + "_generator"

    # `tags` is set explicitly on the generator, so it must be removed from the
    # forwarded kwargs there; passing it twice is a duplicate-keyword error.
    generator_kwargs = dict(kwargs)
    generator_tags = list(generator_kwargs.pop("tags", None) or [])
    if "manual" not in generator_tags:
        generator_tags.append("manual")

    _cc_ir_header_rule(
        base_name = name,
        name = generator_name,
        tags = generator_tags,
        src = src,
        deps = deps,
        out_header = out_header,
        # `namespace` exists only on the generator rule, so it is a named macro
        # parameter rather than part of the kwargs shared with cc_library.
        namespace = namespace,
        # copybara_removed compatible_with = ["//buildenv/target:non_prod"],
        **generator_kwargs
    )

    cc_library(
        name = name,
        hdrs = [":" + out_header],
        **kwargs
    )
